#!/usr/bin/env python3

##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from enum import Enum
from typing import Container, Iterable, List, Optional, Set, cast

from avro.errors import AvroRuntimeException
from avro.schema import (
    ArraySchema,
    EnumSchema,
    Field,
    FixedSchema,
    MapSchema,
    NamedSchema,
    RecordSchema,
    Schema,
    UnionSchema,
)


class SchemaType(str, Enum):
    """
    Names of the Avro schema types handled by the compatibility checker.

    Because the enum mixes in ``str``, each member compares equal to (and hashes like)
    its plain type name, e.g. ``SchemaType.INT == "int"``.
    """

    ARRAY = "array"
    BOOLEAN = "boolean"
    BYTES = "bytes"
    DOUBLE = "double"
    ENUM = "enum"
    FIXED = "fixed"
    FLOAT = "float"
    INT = "int"
    LONG = "long"
    MAP = "map"
    NULL = "null"
    RECORD = "record"
    STRING = "string"
    UNION = "union"

    def __str__(self) -> str:
        # Render as the bare type name ("int") instead of the member name ("SchemaType.INT").
        return str(self.value)


class SchemaCompatibilityType(Enum):
    """Overall verdict carried by a SchemaCompatibilityResult."""

    # The reader can read what the writer writes.
    compatible = "compatible"
    # At least one incompatibility was found.
    incompatible = "incompatible"
    # Default state of a freshly constructed SchemaCompatibilityResult; the checker stores
    # such a result as a placeholder while a reader/writer pair is still being evaluated.
    recursion_in_progress = "recursion_in_progress"


class SchemaIncompatibilityType(Enum):
    """Kinds of incompatibility that a reader/writer schema comparison can report."""

    # Named schemas whose names differ and no reader alias matches (see check_schema_names).
    name_mismatch = "name_mismatch"
    # Fixed schemas with different sizes (see check_fixed_size).
    fixed_size_mismatch = "fixed_size_mismatch"
    # Writer enum has symbols the reader lacks and the reader has no usable default
    # (see check_reader_enum_contains_writer_enum).
    missing_enum_symbols = "missing_enum_symbols"
    # Reader record field with no matching writer field and no default value.
    reader_field_missing_default_value = "reader_field_missing_default_value"
    # Reader and writer types differ and no accepted promotion applies (see type_mismatch).
    type_mismatch = "type_mismatch"
    # A writer union branch cannot be read by the reader, or no branch of a reader union
    # can read the writer schema.
    missing_union_branch = "missing_union_branch"


# Schema types that need no further inspection: when reader and writer share one of
# these types, calculate_compatibility reports them compatible immediately.
PRIMITIVE_TYPES = {
    SchemaType.NULL,
    SchemaType.BOOLEAN,
    SchemaType.INT,
    SchemaType.LONG,
    SchemaType.FLOAT,
    SchemaType.DOUBLE,
    SchemaType.BYTES,
    SchemaType.STRING,
}


class SchemaCompatibilityResult:
    """
    Outcome of comparing a reader schema with a writer schema.

    Carries the overall verdict, the kinds of incompatibility found, the messages
    describing them and the schema locations they refer to.
    """

    def __init__(
        self,
        compatibility: SchemaCompatibilityType = SchemaCompatibilityType.recursion_in_progress,
        incompatibilities: Optional[List[SchemaIncompatibilityType]] = None,
        messages: Optional[Set[str]] = None,
        locations: Optional[Set[str]] = None,
    ):
        """
        :param compatibility: overall verdict; a bare instance is "recursion in progress"
        :param incompatibilities: kinds of incompatibility found; falsy means none
        :param messages: messages describing the incompatibilities; falsy means none
        :param locations: locations the result refers to; falsy means the root, "/"
        """
        # Any falsy argument (None or an empty container) is replaced by a fresh default,
        # so instances never share a mutable default object.
        self.compatibility = compatibility
        self.incompatibilities = incompatibilities if incompatibilities else []
        self.messages = messages if messages else set()
        self.locations = locations if locations else {"/"}


def merge(this: SchemaCompatibilityResult, that: SchemaCompatibilityResult) -> SchemaCompatibilityResult:
    """
    Combines two results into a new SchemaCompatibilityResult without modifying either input.

    The incompatibilities of both results are concatenated, those of ``this`` first.
    While ``this`` is compatible the merged result takes the verdict, messages and
    locations of ``that``; otherwise it keeps the verdict of ``this`` and pools the
    messages and locations of both.
    :param this: SchemaCompatibilityResult
    :param that: SchemaCompatibilityResult
    :return: SchemaCompatibilityResult
    """
    combined = list(this.incompatibilities)
    combined.extend(that.incompatibilities)

    if this.compatibility is not SchemaCompatibilityType.compatible:
        # `this` already carries a non-compatible verdict: keep it and pool the diagnostics.
        return SchemaCompatibilityResult(
            compatibility=this.compatibility,
            incompatibilities=combined,
            messages=this.messages.union(that.messages),
            locations=this.locations.union(that.locations),
        )

    # `this` contributes nothing but its (normally empty) incompatibility list.
    return SchemaCompatibilityResult(
        compatibility=that.compatibility,
        incompatibilities=combined,
        messages=that.messages,
        locations=that.locations,
    )


# Single shared "compatible, nothing to report" result. It is the starting value for every
# chain of merge() calls and is returned as-is by the checks below, so treat it as read-only.
CompatibleResult = SchemaCompatibilityResult(SchemaCompatibilityType.compatible)


class ReaderWriter:
    """
    Ordered (reader, writer) schema pair used as a dictionary key.

    Pairs are compared by object identity: two pairs are equal only when they hold
    the very same reader object and the very same writer object.
    """

    def __init__(self, reader: Schema, writer: Schema) -> None:
        """
        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        """
        self.reader, self.writer = reader, writer

    def __hash__(self) -> int:
        # Hash the ordered pair of identities. XOR-ing the two ids would be symmetric
        # ((a, b) colliding with (b, a)) and would yield 0 for every pair whose reader
        # and writer are the same object, putting all such keys in one hash bucket.
        return hash((id(self.reader), id(self.writer)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ReaderWriter):
            return False
        return self.reader is other.reader and self.writer is other.writer


class ReaderWriterCompatibilityChecker:
    """
    Decides whether data written with a writer schema can be read using a reader schema.

    Results are cached per reader/writer pair (keyed by object identity through ReaderWriter);
    the cache also acts as a recursion guard so that self-referencing schemas terminate.
    """

    # Reference token used for the top-level schema pair.
    ROOT_REFERENCE_TOKEN = "/"

    def __init__(self) -> None:
        # ReaderWriter pair -> SchemaCompatibilityResult. While a pair is being evaluated it maps
        # to a default-constructed result, whose compatibility is recursion_in_progress.
        self.memoize_map = {}

    def get_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        reference_token: str = ROOT_REFERENCE_TOKEN,
        location: Optional[List[str]] = None,
    ) -> SchemaCompatibilityResult:
        """
        Returns the compatibility of a reader/writer schema pair, computing it at most once per pair.
        A pair whose evaluation has started but not finished is reported as compatible, which
        breaks cycles in recursive schemas.
        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        :param reference_token: name of the step leading to this pair; appended to location
        :param location: reference tokens leading to the parent of this pair; defaults to an empty path
        :return: SchemaCompatibilityResult
        """
        if location is None:
            location = []
        pair = ReaderWriter(reader, writer)
        if pair in self.memoize_map:
            # A cached result is returned unchanged, so its locations are the ones
            # recorded when the pair was first computed.
            result = cast(SchemaCompatibilityResult, self.memoize_map[pair])
            if result.compatibility is SchemaCompatibilityType.recursion_in_progress:
                result = CompatibleResult
        else:
            # Store a recursion_in_progress placeholder first so that a nested visit of the
            # same pair takes the branch above instead of recursing forever.
            self.memoize_map[pair] = SchemaCompatibilityResult()
            result = self.calculate_compatibility(reader, writer, location + [reference_token])
            self.memoize_map[pair] = result
        return result

    # pylint: disable=too-many-return-statements
    def calculate_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        location: List[str],
    ) -> SchemaCompatibilityResult:
        """
        Calculates the compatibility of a reader/writer schema pair. Will be positive if the reader is capable of reading
        whatever the writer may write
        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        :param location: List[str]
        :return: SchemaCompatibilityResult
        :raises AvroRuntimeException: if the reader type is not one of the handled schema types
        """
        assert reader is not None
        assert writer is not None
        result = CompatibleResult
        # Case 1: both schemas have the same type; compare their contents.
        if reader.type == writer.type:
            if reader.type in PRIMITIVE_TYPES:
                return result
            if reader.type == SchemaType.ARRAY:
                reader, writer = cast(ArraySchema, reader), cast(ArraySchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.items, writer.items, "items", location),
                )
            if reader.type == SchemaType.MAP:
                reader, writer = cast(MapSchema, reader), cast(MapSchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.values, writer.values, "values", location),
                )
            if reader.type == SchemaType.FIXED:
                reader, writer = cast(FixedSchema, reader), cast(FixedSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(result, check_fixed_size(reader, writer, location))
            if reader.type == SchemaType.ENUM:
                reader, writer = cast(EnumSchema, reader), cast(EnumSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    check_reader_enum_contains_writer_enum(reader, writer, location),
                )
            if reader.type == SchemaType.RECORD:
                reader, writer = cast(RecordSchema, reader), cast(RecordSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    self.check_reader_writer_record_fields(reader, writer, location),
                )
            if reader.type == SchemaType.UNION:
                reader, writer = cast(UnionSchema, reader), cast(UnionSchema, writer)
                # Every writer branch must be readable by the reader union. The nested check is
                # made without a location, so it starts again from the root token; only the
                # summary below is reported at this location plus the branch index.
                for i, writer_branch in enumerate(writer.schemas):
                    compat = self.get_compatibility(reader, writer_branch)
                    if compat.compatibility is SchemaCompatibilityType.incompatible:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.missing_union_branch,
                                f"reader union lacking writer type: {writer_branch.type.upper()}",
                                location + [str(i)],
                            ),
                        )
                return result
            raise AvroRuntimeException(f"Unknown schema type: {reader.type}")
        # Case 2: types differ and the writer is a union; the reader must be able to read
        # every branch the writer may have used.
        if writer.type == SchemaType.UNION:
            writer = cast(UnionSchema, writer)
            for s in writer.schemas:
                result = merge(result, self.get_compatibility(reader, s))
            return result
        # Case 3: types differ and the writer is not a union; only the promotions below are accepted.
        if reader.type in {SchemaType.NULL, SchemaType.BOOLEAN, SchemaType.INT}:
            # Nothing can be promoted to null, boolean or int.
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.LONG:
            # long reads int
            if writer.type == SchemaType.INT:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.FLOAT:
            # float reads int and long
            if writer.type in {SchemaType.INT, SchemaType.LONG}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.DOUBLE:
            # double reads int, long and float
            if writer.type in {SchemaType.INT, SchemaType.LONG, SchemaType.FLOAT}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.BYTES:
            # bytes reads string
            if writer.type == SchemaType.STRING:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.STRING:
            # string reads bytes
            if writer.type == SchemaType.BYTES:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type in {
            SchemaType.ARRAY,
            SchemaType.MAP,
            SchemaType.FIXED,
            SchemaType.ENUM,
            SchemaType.RECORD,
        }:
            # Complex reader types are never compatible with a writer of a different type.
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.UNION:
            # The reader is a union and the writer is not: one readable branch is enough.
            reader = cast(UnionSchema, reader)
            for reader_branch in reader.schemas:
                compat = self.get_compatibility(reader_branch, writer)
                if compat.compatibility is SchemaCompatibilityType.compatible:
                    return result
            # No branch in reader compatible with writer
            message = f"reader union lacking writer type {writer.type}"
            return merge(
                result,
                incompatible(SchemaIncompatibilityType.missing_union_branch, message, location),
            )
        raise AvroRuntimeException(f"Unknown schema type: {reader.type}")

    # pylint: enable=too-many-return-statements

    def check_reader_writer_record_fields(self, reader: RecordSchema, writer: RecordSchema, location: List[str]) -> SchemaCompatibilityResult:
        """
        Checks every field of the reader record against the writer record.

        A reader field that has a writer counterpart (found by name or by reader alias, see
        lookup_writer_field) must have a type compatible with it. A reader field without a
        counterpart is accepted when it has a default value; otherwise it is reported as
        reader_field_missing_default_value, except for the enum case noted in the body.
        :param reader: avro.schema.RecordSchema
        :param writer: avro.schema.RecordSchema
        :param location: List[str]
        :return: SchemaCompatibilityResult
        """
        result = CompatibleResult
        for i, reader_field in enumerate(reader.fields):
            reader_field = cast(Field, reader_field)
            writer_field = lookup_writer_field(writer_schema=writer, reader_field=reader_field)
            if writer_field is None:
                if not reader_field.has_default:
                    # The field itself has no default. If its type is an enum whose props carry a
                    # truthy "default", the enum type is compared with the writer record schema
                    # itself; in every other case the missing default is reported directly.
                    if reader_field.type.type == SchemaType.ENUM and reader_field.type.props.get("default"):
                        result = merge(
                            result,
                            self.get_compatibility(
                                reader_field.type,
                                writer,
                                "type",
                                location + ["fields", str(i)],
                            ),
                        )
                    else:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.reader_field_missing_default_value,
                                reader_field.name,
                                location + ["fields", str(i)],
                            ),
                        )
            else:
                # Both sides have the field: its types must be compatible.
                result = merge(
                    result,
                    self.get_compatibility(
                        reader_field.type,
                        writer_field.type,
                        "type",
                        location + ["fields", str(i)],
                    ),
                )
        return result


def type_mismatch(reader: Schema, writer: Schema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Builds the incompatible result for a reader and a writer whose types cannot be reconciled.
    :param reader: avro.schema.Schema
    :param writer: avro.schema.Schema
    :param location: List[str]
    :return: SchemaCompatibilityResult
    """
    return incompatible(
        SchemaIncompatibilityType.type_mismatch,
        f"reader type: {reader.type} not compatible with writer type: {writer.type}",
        location,
    )


def check_schema_names(reader: NamedSchema, writer: NamedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Checks that the reader schema is named like the writer schema (see schema_name_equals).
    :param reader: avro.schema.NamedSchema
    :param writer: avro.schema.NamedSchema
    :param location: List[str]
    :return: CompatibleResult on a match, otherwise a name_mismatch result located at "name"
    """
    if schema_name_equals(reader, writer):
        return CompatibleResult
    return incompatible(
        SchemaIncompatibilityType.name_mismatch,
        f"expected: {writer.fullname}",
        location + ["name"],
    )


def check_fixed_size(reader: FixedSchema, writer: FixedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Checks that reader and writer fixed schemas have the same size.
    :param reader: avro.schema.FixedSchema
    :param writer: avro.schema.FixedSchema
    :param location: List[str]
    :return: CompatibleResult when the sizes match, otherwise a fixed_size_mismatch result located at "size"
    """
    reader_size, writer_size = reader.size, writer.size
    if reader_size != writer_size:
        # The writer's size is the expected one; the reader's is what was found.
        return incompatible(
            SchemaIncompatibilityType.fixed_size_mismatch,
            f"expected: {writer_size}, found: {reader_size}",
            location + ["size"],
        )
    return CompatibleResult


def check_reader_enum_contains_writer_enum(reader: EnumSchema, writer: EnumSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Checks that the reader enum can represent every symbol of the writer enum.

    Writer symbols unknown to the reader are tolerated only when the reader's "default"
    prop is set and names one of the reader's own symbols.
    :param reader: avro.schema.EnumSchema
    :param writer: avro.schema.EnumSchema
    :param location: List[str]
    :return: CompatibleResult, or a missing_enum_symbols result located at "symbols"
    """
    writer_symbols, reader_symbols = set(writer.symbols), set(reader.symbols)
    unknown_to_reader = writer_symbols - reader_symbols
    if not unknown_to_reader:
        return CompatibleResult

    fallback = reader.props.get("default")
    if fallback and fallback in reader_symbols:
        return CompatibleResult

    return incompatible(
        SchemaIncompatibilityType.missing_enum_symbols,
        str(unknown_to_reader),
        location + ["symbols"],
    )


def incompatible(incompat_type: SchemaIncompatibilityType, message: str, location: List[str]) -> SchemaCompatibilityResult:
    """
    Builds an incompatible result holding a single incompatibility.
    :param incompat_type: SchemaIncompatibilityType
    :param message: str
    :param location: List[str] of reference tokens, joined with "/" into one location string
    :return: SchemaCompatibilityResult
    """
    joined = "/".join(location)
    # With more than one token the first character of the joined path is dropped; when the
    # first token is the root "/", this removes the doubled leading slash.
    pointer = joined[1:] if len(location) > 1 else joined
    return SchemaCompatibilityResult(
        compatibility=SchemaCompatibilityType.incompatible,
        incompatibilities=[incompat_type],
        messages={message},
        locations={pointer},
    )


def schema_name_equals(reader: NamedSchema, writer: NamedSchema) -> bool:
    """
    Tells whether the reader schema is named like the writer schema.

    The schemas match when their names are equal, or when the reader's "aliases" prop
    is a container holding the writer's full name.
    :param reader: avro.schema.NamedSchema
    :param writer: avro.schema.NamedSchema
    :return: bool
    """
    reader_aliases = reader.props.get("aliases")
    if reader.name == writer.name:
        return True
    return isinstance(reader_aliases, Container) and writer.fullname in reader_aliases


def lookup_writer_field(writer_schema: RecordSchema, reader_field: Field) -> Optional[Field]:
    """
    Finds the writer field that corresponds to a reader field.

    The writer field with the reader field's own name wins; failing that, the reader
    field's "aliases" prop is tried in order and the first alias naming a writer field is used.
    :param writer_schema: avro.schema.RecordSchema
    :param reader_field: avro.schema.Field
    :return: the matching writer Field, or None when there is none
    """
    by_name = writer_schema.fields_dict.get(reader_field.name)
    if by_name:
        return cast(Field, by_name)

    reader_aliases = reader_field.props.get("aliases")
    if not isinstance(reader_aliases, Iterable):
        return None

    # Lazy lookup: stops at the first alias that names a writer field.
    candidates = (writer_schema.fields_dict.get(alias) for alias in reader_aliases)
    by_alias = next((field for field in candidates if field is not None), None)
    return None if by_alias is None else cast(Field, by_alias)
